# -*- coding: utf-8 -*-
import pickle
import pandas as pd
import numpy as np
from scipy.stats import ttest_ind, f, kruskal
from finding_bi_star import clustering
from draw_h_r_diagram import plot_HRdiagram

def anova(X, alpha):
    '''
    One-way ANOVA F-test across the rows of ``X``.

    Each row of ``X`` is one group (level) and each column one repeated
    observation of that group, i.e. the table layout is::

        level (sample id): test 1, test 2, ..., test n

    Parameters
    ----------
    X : pandas.DataFrame
        Numeric table; rows are groups, columns are observations.
    alpha : float
        Significance level of the test.

    Returns
    -------
    numpy.bool_
        True when F < F_crit at level ``alpha``, i.e. the null
        hypothesis "all group means are equal" is NOT rejected.
    '''
    rows, columns = X.shape
    # Total sample size.
    m = rows * columns

    values = X.to_numpy(dtype=float)
    # Per-group (row) means; groups are equal-sized, so the grand mean
    # is the mean of the group means.
    group_means = values.mean(axis=1)
    grand_mean = group_means.mean()

    # Within-group sum of squares.
    SS2 = np.square(values - group_means[:, None]).sum()
    # Between-group sum of squares (each group has ``columns`` samples).
    SS1 = columns * np.square(group_means - grand_mean).sum()

    # Degrees of freedom: between groups / within groups.
    dfn = rows - 1
    dfd = m - rows
    # Mean squares and the F statistic.
    MS1 = SS1 / dfn
    MS2 = SS2 / dfd
    F = MS1 / MS2
    # Critical value of the F distribution at the 1 - alpha quantile
    # (the original comment hard-coded 0.95; it depends on ``alpha``).
    F_alpha = f.ppf(1 - alpha, dfn, dfd)

    return F < F_alpha
    
def cluster_param_select(data, \
                        data_bi, \
                        epsilon_list, \
                        min_samples_list, \
                        coef_list, \
                        cluster_targets=['pmRA', 'pmDE'],
                        alpha=0.1):
    '''
    Grid-search the best clustering parameters.

    ``data`` is the data set with the Hyades cluster members removed;
    ``data_bi`` holds the known members and supplies the reference
    pmRA/pmDE distributions for the t-tests.

    A parameter combination survives when its dominant (non-noise)
    cluster passes a one-way ANOVA on ['pmRA', 'pmDE'] and two-sample
    t-tests against the reference (p >= alpha).  Survivors are ranked
    by dominant-cluster size, then by number of clusters.

    NOTE(review): ``data`` is mutated in place (RA/DE columns dropped);
    callers are expected to pass a copy.

    Returns (best_eps, best_min_samples, best_coef), or (0, 0, 0) when
    the whole grid fails the tests.
    '''
    # Surviving parameter combos -> (dominant-cluster size, #clusters).
    params_dict = dict()

    # Drop the positional columns so they do not influence clustering.
    data.drop(['RA', 'DE'], axis=1, inplace=True)

    pmRA_target = data_bi['pmRA']
    pmDE_target = data_bi['pmDE']
    del data_bi

    # Walk the parameter grid.
    for eps in epsilon_list:
        for min_samples in min_samples_list:
            for coef in coef_list:
                labels = clustering(data, eps,
                                    min_samples, coef,
                                    cluster_targets)
                # Throw-away copy used only for label-based filtering.
                data_tmp = data.copy()
                data_tmp['label'] = labels

                clusters = np.unique(labels)
                num_clusters = len(clusters)
                if num_clusters == 1:
                    # A single cluster carries no information; skip.
                    continue

                # Member count per cluster label.
                clu_num_dict = dict()
                for label in labels:
                    clu_num_dict[label] = clu_num_dict.get(label, 0) + 1

                # Clusters sorted by size, biggest first.
                num_main_clu = sorted(clu_num_dict.items(),
                                      key=lambda x: x[1],
                                      reverse=True)
                if num_main_clu == []:
                    # Clustering produced nothing; skip.
                    continue

                # Dominant cluster; -1 marks noise points, so fall back
                # to the runner-up when noise is the biggest group.
                max_cluster, max_num = num_main_clu[0]
                if max_cluster == -1:
                    max_cluster, max_num = num_main_clu[1]

                # Members of the dominant cluster.
                data_tmp = data_tmp.loc[data_tmp['label'] == max_cluster]
                pmRA = data_tmp['pmRA']
                pmDE = data_tmp['pmDE']

                if not anova(data_tmp[['pmRA', 'pmDE']], alpha):
                    # Failed the ANOVA test.
                    continue

                # Two-sample t-tests against the Hyades reference.
                _, p1 = ttest_ind(pmRA, pmRA_target)
                _, p2 = ttest_ind(pmDE, pmDE_target)
                if p1 < alpha or p2 < alpha:
                    # pmRA / pmDE differ significantly; reject.
                    continue

                if -1 in clusters:
                    # Noise points do not count as a cluster.
                    num_clusters -= 1

                params_dict[(eps, min_samples, coef)] = (max_num,
                                                         num_clusters)

    # Rank survivors: dominant-cluster size first, then cluster count.
    max_param = sorted(params_dict.items(),
                       key=lambda x: (x[1][0], x[1][1]),
                       reverse=True)
    if max_param == []:
        print('参数网格全军覆没...')
        return 0, 0, 0

    # Best entry: ((eps, min_samples, coef), (max_num, num_clusters)).
    (best_eps, best_min_samples, best_coef), \
        (best_max_num, best_num_cluster) = max_param[0]

    # BUG FIX: the report used to interpolate the loop variable ``coef``
    # (the last grid value) instead of the selected ``best_coef``.
    print(f'聚类最佳参数为： \n epsilon: {best_eps}： \n min_samples: {best_min_samples}' +
                f'\n 系数： {best_coef} \n 簇最大个体数 {best_max_num}' +
                f'\n 簇数： {best_num_cluster}')

    return best_eps, best_min_samples, best_coef

def bi_star_stream(data, df, bi_star_HIP,
                    eps=0.5, min_samples=2, coef=10, \
                    cluster_targets=['pmRA', 'pmDE']
                    ):
    '''
    Extract the Hyades stream (largest non-noise cluster) from ``data``.

    Parameters
    ----------
    data : pandas.DataFrame
        Data with the Hyades cluster members removed.  Mutated in place:
        the RA/DE columns are dropped before clustering.
    df : pandas.DataFrame
        Original catalogue.  NOTE(review): assumed to align row-wise
        with ``data`` after removing the ``bi_star_HIP`` entries —
        confirm upstream preprocessing preserves row order.
    bi_star_HIP : collection
        HIP identifiers of the known Hyades cluster members.
    eps, min_samples, coef :
        Parameters forwarded to ``clustering``.
    cluster_targets : list of str
        Columns the clustering operates on.

    Returns
    -------
    pandas.DataFrame
        Rows of ``df`` that belong to the dominant cluster.
    '''
    # Drop positional columns so they do not influence clustering.
    data.drop(['RA', 'DE'], axis=1, inplace=True)

    labels = clustering(data, eps=eps, \
                        min_samples=min_samples, \
                        coef=coef,
                        cluster_targets=cluster_targets)

    # Count members per label, ignoring noise points (label == -1).
    clu_and_num = dict()
    for label in labels:
        if label == -1:
            continue
        clu_and_num[label] = clu_and_num.get(label, 0) + 1

    clu_num_sort = sorted(clu_and_num.items(), key=lambda x: x[1], \
                          reverse=True)

    # Label and size of the dominant cluster (the Hyades stream).
    max_clu, max_num = clu_num_sort[0]
    # BUG FIX: the message used to print the cluster *label* (max_clu)
    # instead of the member count.
    print(f'毕宿星流共有 {max_num} 颗星')

    df_without_bi = df.loc[~df['HIP'].isin(bi_star_HIP)].copy()
    df_without_bi['label'] = labels

    bi_stream = df_without_bi.loc[df_without_bi['label'] == max_clu]
    return bi_stream
    

if __name__ == '__main__':
    # Project artefact paths (relative to this script).
    scaler_model = r'../附件/scaler_model.pkl'
    HIP_path = r'../附件/bi_star_cluster_HIP.pkl'
    file_path = r'../附件/data_ml.pkl'
    path = r'../附件/星表数据.txt'

    # Read the raw star catalogue.  The multi-character separator ' +'
    # is a regex, which requires the python engine; naming the engine
    # explicitly avoids the pandas ParserWarning on implicit fallback.
    df = pd.read_csv(path, sep=' +', engine='python')

    # Load project-generated artefacts.  Context managers close the
    # handles deterministically (the original leaked all four handles).
    # NOTE(review): pickle is unsafe on untrusted data; these files are
    # produced by this pipeline itself.
    with open(HIP_path, 'rb') as fh:
        bi_star_HIP = pickle.load(fh)
    # Standardized data with missing values already removed.
    with open(file_path, 'rb') as fh:
        data_standard = pickle.load(fh)
    # Loaded for parity with the pipeline; not used below.
    with open(scaler_model, 'rb') as fh:
        scaler = pickle.load(fh)

    # Split off the known Hyades cluster members.
    # NOTE(review): assumes data_standard rows align with df rows —
    # confirm upstream preprocessing preserves the row order.
    mask_bi = df['HIP'].isin(bi_star_HIP)
    data_remove_bi = data_standard.loc[~mask_bi]
    data_with_bi = data_standard.loc[mask_bi]

    # Parameter grid for the search.
    epsilon_list = [0.01, 0.05, 0.1, 0.3, 0.5, 0.7, 1, 1.5, 2, 2.5, 3, 3.5]
    min_samples_list = [2, 3, 4]
    coef_list = [1.5, 2, 2.5, 3.0, 3.5, 4, 4.5, 5, 5.5, 6, 6.5, 7, 7.5,
                 8, 8.5, 9, 9.5, 10]

    cluster_targets = ['pmRA', 'pmDE']

    # Work on a copy: cluster_param_select drops columns in place.
    data = data_remove_bi.copy()
    data_bi = data_with_bi.copy()
    best_eps, best_min_samples, best_coef = cluster_param_select(
        data,
        data_bi,
        epsilon_list,
        min_samples_list,
        coef_list,
        cluster_targets=cluster_targets,
        alpha=0.3)

    # Fresh copy: bi_star_stream also mutates its input in place.
    data = data_remove_bi.copy()
    bi_stream = bi_star_stream(data, df, bi_star_HIP,
                               eps=best_eps,
                               min_samples=best_min_samples,
                               coef=best_coef,
                               cluster_targets=cluster_targets)

    # Persist the extracted stream.
    with open(r'../附件/bi_stream.pkl', 'wb') as fh:
        pickle.dump(bi_stream, fh)

    # Plot the Hertzsprung-Russell diagram of the stream.
    x = bi_stream['B-V']
    y = bi_stream['Vmag']
    plot_HRdiagram(x, y)
    
